// Copyright 1998-2016 Glenn McIntosh
// licensed under the GNU General Public Licence version 3
#pragma once
#include <cmath>

/** @file transform.h
	time/frequency transform functions
	*/
// include files
#include <cstddef>
#include <vector>

namespace math
{
/** real type */
using real = double;

/** array of reals */
using Vector = std::vector<real>;

/** 2D array of reals */
using Vector2D = std::vector<std::vector<real>>;

/** Time/frequency transform a vector using decimate-in-time fast Hartley method.
	@param d is the vector to be transformed in place (length a power of 2)
	*/
void hartley(Vector &d);

/** Time/frequency transform a two dimensional vector using decimate-in-time fast Hartley method.
	@param d is the vector to be transformed in place (length a power of 2)
	*/
void hartley(Vector2D &d);

/** Auto-correlate a vector.
	@param d is the domain data to be auto-correlated (length a power of 2)
	*/
void autoCorrelate(Vector &d);

/** Convolve vectors.
	@param d is the domain data to be convolved (length a power of 2)
	@param f is the filter data (length an odd number)
	*/
void convolve(Vector &d, const Vector &f);

/** Correlate matrices.
	@param d is the domain data to be cross-correlated (length a power of 2)
	@param f is the filter data (length an odd number)
	*/
void correlate(Vector2D &d, const Vector2D &f);

/** Convolve matrices.
	@param d is the domain data to be convolved (length a power of 2)
	@param f is the filter data (length an odd number)
	*/
void convolve(Vector2D &d, const Vector2D &f);

/** Hann window a vector
	@param d is the vector to be windowed in place
	*/
void hann(Vector &d);

/** Haar wavelet forward transform
	@param d is the vector to be transformed (length a power of 2)
	*/
void haarTransform0(Vector &d);

/** Haar wavelet inverse transform
	@param d is the vector to be transformed (length a power of 2)
	*/
void haarTransform1(Vector &d);

/** Daubechies D4 wavelet forward transform
	@param d is the vector to be transformed in place (length a power of 2)
	@param approx return approximation instead of difference
	*/
void d4Wavelet0(Vector &d, bool approx = false);

/** Daubechies D4 wavelet inverse transform
	@param d is the vector to be transformed in place (length a power of 2)
	*/
void d4Wavelet1(Vector &d);


/** continuous Daubechies D4 wavelet transform

	Samples are fed one at a time through operator(). The low bits of an
	internal sample counter select which scale is updated by each sample:
	an odd count at some scale runs the lifting steps there and yields a
	difference value, an even count shifts the sample into that scale and
	passes the stored approximation on to the next scale.
	@tparam maxScale maximum scale to be returned (must be positive)
	*/
template<int maxScale> class D4WaveletFilter
{
	static_assert(maxScale > 0, "maxScale must be positive");
public:
	/** construct a transform with all state zeroed; scale() is maxScale
		until a sample produces a difference value
	*/
	D4WaveletFilter();
public:
	/** add a new data point
	@param x value
	@result value after transform: the difference at scale(), or the
		residual approximation when scale() == maxScale
	*/
	real operator()(real x);

	/** get the approximation at the current scale;
		when scale() == maxScale this is the residual approximation last
		returned by operator() (0 before any such value exists) */
	real getApproximation() const {return l0[iScale];}

	/** get the scale of the current value (0..maxScale) */
	int scale() const {return iScale;}

	/** get the filter lag, (3 << (scale+1)) - 3 */
	int lag() const {return (3 << (iScale + 1)) - 3;}
private:
	// sample counter; unsigned so that wrap-around is well defined
	// (only the low maxScale bits are ever examined)
	unsigned i{0};
	// scale that produced the most recent value, maxScale if none did
	int iScale;
	// l0 has one extra slot, l0[maxScale], holding the residual approximation
	// so that getApproximation() never indexes past the end of the array
	real l0[maxScale + 1];
	real l1[maxScale];
	real h1[maxScale];
};

// initialization: zero all filter state, no scale has produced a value yet
template<int maxScale> D4WaveletFilter<maxScale>::D4WaveletFilter()
	: iScale(maxScale), l0(), l1(), h1()
{
}

// continuous Daubechies D4 wavelet transform
template<int maxScale> real D4WaveletFilter<maxScale>::operator()(real x)
{
	// lifting constants, evaluated explicitly in 'real' precision; an
	// unqualified sqrt() on a float literal may pick a single precision
	// overload on some standard libraries
	const real sqrt2 = std::sqrt(real(2));
	const real sqrt3 = std::sqrt(real(3));

	// at each scale
	unsigned j = i++;
	for (iScale = 0; iScale < maxScale; ++iScale)
	{
		// if odd
		if (j & 1u)
		{
			// add in new data
			h1[iScale] = x;

			// lifting steps
			l1[iScale] += sqrt3*h1[iScale];
			h1[iScale] -= (sqrt3/4)*l1[iScale] + ((sqrt3-2)/4)*l0[iScale];
			l0[iScale] -= h1[iScale];
			l0[iScale] *= (sqrt3-1)/2;

			// return difference
			return h1[iScale] * (sqrt3+1)/sqrt2;
		}

		// else cascade to next scale with even data
		const real x0 = l0[iScale];
		l0[iScale] = l1[iScale];
		l1[iScale] = x;
		x = x0;
		j >>= 1;
	}

	// every scale was even: iScale == maxScale here, keep the residual
	// approximation where getApproximation() will look for it
	l0[maxScale] = x;
	return x;
}

/** continuous LeGall 5/3 wavelet transform

	Samples are fed one at a time through operator(). The low bits of an
	internal sample counter select which scale is updated by each sample:
	an odd count at some scale runs the lifting steps there and yields a
	difference value, an even count stores the sample at that scale and
	passes the stored approximation on to the next scale.
	@tparam maxScale maximum scale to be returned (must be positive)
	*/
template<int maxScale> class LegallWaveletFilter
{
	static_assert(maxScale > 0, "maxScale must be positive");
public:
	/** construct a transform with all state zeroed; scale() is maxScale
		until a sample produces a difference value
	*/
	LegallWaveletFilter();
public:
	/** add a new data point
	@param v value
	@result value after transform: the difference at scale(), or the
		residual approximation when scale() == maxScale
	*/
	real operator()(real v);

	/** get the approximation at the current scale;
		when scale() == maxScale this is the residual approximation last
		returned by operator() (0 before any such value exists) */
	real getApproximation() const {return l0[iScale];}

	/** get the difference at the current scale;
		0 when scale() == maxScale, as no difference exists there */
	real getDifference() const {return h0[iScale];}

	/** get the scale of the current value (0..maxScale) */
	int scale() const {return iScale;}

	/** get the filter lag, (3 << (scale+1)) - 3 - (1 << scale) */
	int lag() const {return (3 << (iScale + 1)) - 3 - (1 << iScale);}
private:
	// sample counter, pre-incremented so the first sample sees the value 3;
	// unsigned so that wrap-around is well defined
	unsigned i{2};
	// scale that produced the most recent value, maxScale if none did
	int iScale;
	real x0[maxScale];
	real x1[maxScale];
	// h0 and l0 have one extra slot at index maxScale so that the accessors
	// never index past the end: h0[maxScale] stays 0 and l0[maxScale] holds
	// the residual approximation
	real h0[maxScale + 1];
	real l0[maxScale + 1];
};

// initialization: zero all filter state, no scale has produced a value yet
template<int maxScale> LegallWaveletFilter<maxScale>::LegallWaveletFilter()
	: iScale(maxScale), x0(), x1(), h0(), l0()
{
}

// continuous LeGall 5/3 wavelet transform
template<int maxScale> real LegallWaveletFilter<maxScale>::operator()(real x2)
{
	// at each scale
	unsigned j = ++i;
	for (iScale = 0; iScale < maxScale; ++iScale)
	{
		// if odd
		if (j & 1u)
		{
			// lifting steps: the difference is the stored centre sample x1
			// minus the mean of its neighbours x0 and x2, the approximation
			// is x0 plus a quarter of the previous and new differences
			const real h1 = x1[iScale]-(x0[iScale]+x2)/2;
			l0[iScale] = x0[iScale]+(h0[iScale]+h1)/4;
			h0[iScale] = h1;
			x0[iScale] = x2;

			// return difference
			return h0[iScale];
		}

		// else cascade to next scale with even data
		x1[iScale] = x2;
		x2 = l0[iScale];
		j >>= 1;
	}

	// every scale was even: iScale == maxScale here, keep the residual
	// approximation where getApproximation() will look for it
	l0[maxScale] = x2;
	return x2;
}
}
